{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "from tensorflow.examples.tutorials.mnist import mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Network dimensions: these size the placeholders and weight matrices.\n",
    "n_input = 784\n",
    "n_hidden = 500\n",
    "n_labels = 10\n",
    "\n",
    "# Optimisation settings: mini-batch size, gradient-descent step size and\n",
    "# the number of training iterations.\n",
    "batch_size = 100\n",
    "learning_rate = 0.8\n",
    "training_step = 30000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (<ipython-input-37-a63b15bf44f0>, line 10)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"<ipython-input-37-a63b15bf44f0>\"\u001b[0;36m, line \u001b[0;32m10\u001b[0m\n\u001b[0;31m    out = tf.matmul(out_hidden, weights) + biases\u001b[0m\n\u001b[0m      ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "def inference(x_input, reuse=tf.AUTO_REUSE):\n",
    "    \"\"\"Build the forward pass of a network with one ReLU hidden layer.\n",
    "\n",
    "    Args:\n",
    "        x_input: float tensor whose last dimension is ``n_input``; it is\n",
    "            matrix-multiplied with the ``[n_input, n_hidden]`` weights.\n",
    "        reuse: forwarded to both ``tf.variable_scope`` calls. The default\n",
    "            ``tf.AUTO_REUSE`` creates the variables on first use and shares\n",
    "            them on later calls. Pass ``True`` to require that they already\n",
    "            exist, or ``None`` to inherit the enclosing scope's setting.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of raw scores with ``n_labels`` columns (no softmax applied).\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(\"hidden\", reuse=reuse):\n",
    "        weights_hidden = tf.get_variable(\n",
    "            \"weights\", [n_input, n_hidden],\n",
    "            initializer=tf.random_normal_initializer(stddev=0.1))\n",
    "        # Hand get_variable an initializer instance (zeros), not the bare class.\n",
    "        biases_hidden = tf.get_variable(\n",
    "            \"biases\", [n_hidden],\n",
    "            initializer=tf.constant_initializer(0.0))\n",
    "        out_hidden = tf.nn.relu(tf.matmul(x_input, weights_hidden) + biases_hidden)\n",
    "\n",
    "    with tf.variable_scope(\"out\", reuse=reuse):\n",
    "        weights = tf.get_variable(\n",
    "            \"weights\", [n_hidden, n_labels],\n",
    "            initializer=tf.random_normal_initializer(stddev=0.1))\n",
    "        biases = tf.get_variable(\n",
    "            \"biases\", [n_labels],\n",
    "            initializer=tf.constant_initializer(0.0))\n",
    "        # No activation here: the loss applies softmax to these scores itself.\n",
    "        out = tf.matmul(out_hidden, weights) + biases\n",
    "\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    \"\"\"Train the ``inference`` network with mini-batch gradient descent.\n",
    "\n",
    "    Every 1000 steps the current batch loss and the accuracy on the\n",
    "    validation split are printed. Nothing is returned.\n",
    "\n",
    "    Args:\n",
    "        mnist: dataset object providing ``train.next_batch(batch_size)`` and\n",
    "            ``validation.images`` / ``validation.labels``; labels must be\n",
    "            one-hot rows of width ``n_labels`` to match the ``y`` placeholder.\n",
    "    \"\"\"\n",
    "    # Build everything in a private graph so that calling train() again in the\n",
    "    # same kernel does not collide with variables left over from an earlier run.\n",
    "    with tf.Graph().as_default():\n",
    "        x = tf.placeholder(\"float\", [None, n_input])\n",
    "        y = tf.placeholder(\"float\", [None, n_labels])\n",
    "        # The graph is empty at this point, so the variables have to be created\n",
    "        # rather than looked up: reuse must not be True here.\n",
    "        pred = inference(x, reuse=False)\n",
    "\n",
    "        # Loss: mean softmax cross-entropy over the batch.\n",
    "        cross_entropy = tf.reduce_mean(\n",
    "            tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
    "\n",
    "        # Optimisation step.\n",
    "        train_op = tf.train.GradientDescentOptimizer(\n",
    "            learning_rate=learning_rate).minimize(cross_entropy)\n",
    "\n",
    "        # Accuracy: fraction of rows whose arg-max prediction matches the label.\n",
    "        correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "\n",
    "        init = tf.global_variables_initializer()\n",
    "\n",
    "        with tf.Session() as sess:\n",
    "            sess.run(init)\n",
    "\n",
    "            # Feed used for the periodic validation check.\n",
    "            validate_data = {x: mnist.validation.images, y: mnist.validation.labels}\n",
    "\n",
    "            for step in range(training_step):\n",
    "                xs, ys = mnist.train.next_batch(batch_size)\n",
    "                _, loss = sess.run([train_op, cross_entropy], feed_dict={x: xs, y: ys})\n",
    "\n",
    "                if step % 1000 == 0:\n",
    "                    validation_accuracy = sess.run(accuracy, feed_dict=validate_data)\n",
    "                    print(\"after %d training steps, the loss is %g, the validation accuracy is %g\" % (step, loss, validation_accuracy))\n",
    "            print(\"the training is finish!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./mnist_data_sets/train-images-idx3-ubyte.gz\n",
      "Extracting ./mnist_data_sets/train-labels-idx1-ubyte.gz\n",
      "Extracting ./mnist_data_sets/t10k-images-idx3-ubyte.gz\n",
      "Extracting ./mnist_data_sets/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'module' object has no attribute 'constant_initialize'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-39-739c3ca17917>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mmnist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"./mnist_data_sets/\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mone_hot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-38-18420e8c69c8>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(mnist)\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplaceholder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"float\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_input\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplaceholder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"float\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_labels\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0;31m# 损失函数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-34-42615580f265>\u001b[0m in \u001b[0;36minference\u001b[0;34m(x_input, reuse)\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvariable_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"out\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreuse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreuse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_variable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"weights\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mn_hidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_labels\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitializer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom_normal_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstddev\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m         \u001b[0mbiases\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_variable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"biases\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mn_labels\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitializer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconstant_initialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_hidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m 
\u001b[0;34m+\u001b[0m \u001b[0mbiases\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'constant_initialize'"
     ]
    }
   ],
   "source": [
    "# Load the data set from ./mnist_data_sets/ with one_hot=True, then hand it\n",
    "# to the training loop.\n",
    "mnist = input_data.read_data_sets(\n",
    "    \"./mnist_data_sets/\",\n",
    "    one_hot=True,\n",
    ")\n",
    "train(mnist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
